from jax import grad
import jax.numpy as jnp


def tanh(x):
    """Hyperbolic tangent written out by hand so JAX can differentiate it.

    Evaluates tanh(|x|) = (1 - e^{-2|x|}) / (1 + e^{-2|x|}) and restores the
    sign of ``x`` afterwards. Keeping the exponent non-positive means ``exp``
    never overflows. Large-magnitude inputs such as ``x = -100.0`` therefore
    give -1.0 instead of ``nan``, and their derivatives stay finite.

    Args:
        x: Float scalar or array.

    Returns:
        ``tanh(x)``, elementwise.
    """
    positive = x >= 0
    # Pick the exponent's sign with `where` instead of multiplying the result
    # by jnp.sign(x): sign(0) == 0 would make the derivative at x == 0 vanish.
    # Both candidates are finite, so no nan can leak into the gradient.
    y = jnp.exp(jnp.where(positive, -2.0 * x, 2.0 * x))  # y in (0, 1]
    magnitude = (1.0 - y) / (1.0 + y)
    return jnp.where(positive, magnitude, -magnitude)


# First derivative of the hand-written tanh, evaluated at x = 1.0.
grad_tanh = grad(tanh)
print(grad_tanh(1.0))

# grad composes: wrapping grad_tanh in two more grad calls gives the
# third derivative of tanh.
grad3_tanh = grad(grad(grad_tanh))
print(grad3_tanh(1.0))


def abs_val(x):
    """Return the absolute value of ``x`` using plain Python branching.

    The branch is chosen with a Python comparison on ``x``, so exactly one
    branch runs per call. Inputs that are not strictly positive (including 0)
    take the negation branch.

    NOTE(review): the Python-level ``x > 0`` needs a concrete value; this
    will likely not trace under ``jax.jit``. Confirm if jit is ever applied.
    """
    return x if x > 0 else -x

# Differentiate through the Python if/else in abs_val. The slope is +1 on
# the positive branch and -1 on the negation branch.
abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0))   # positive branch
print(abs_val_grad(-1.0))  # negation branch
